{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building Relation Network Using Tensorflow"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Relation network is pretty simple, right? Now, we will understand relation network better by implementing it in TensorFlow. We will consider a simple binary classification problem and see how can we solve them using a relation network. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we import all the required libraries,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us say we have two classes in our dataset, we will randomly generate some 1000 data points for each of these classes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "classA = np.random.rand(1000,18)\n",
    "\n",
    "ClassB = np.random.rand(1000,18)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Create our dataset by combining both of these classes,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "data = np.vstack([classA, ClassB])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we will set the label, we assign label 1 for classA and label 0 for classB. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "label = np.vstack([np.ones((len(classA),1)),np.zeros((len(ClassB),1))])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, our dataset will have 2000 records"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2000, 18)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2000, 1)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "label.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will define the placeholders for our support set $x_i$ and query set $x_j$. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "xi = tf.placeholder(tf.float32, [None, 9])\n",
    "xj = tf.placeholder(tf.float32, [None, 9])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the placeholder for label y,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "y = tf.placeholder(tf.float32, [None, 1]) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we will define our embedding function $f_{\\varphi}$()  to learn the embeddings of support set and query set. We will use a normal feedforward network as our embedding function. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def embedding_function(x):\n",
    "    \n",
    "    weights = tf.Variable(tf.truncated_normal([9,1]))\n",
    "    bias = tf.Variable(tf.truncated_normal([1]))\n",
    "    \n",
    "    a = (tf.nn.xw_plus_b(x,weights,bias))\n",
    "    embeddings = tf.nn.relu(a)\n",
    "    \n",
    "    return embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the embeddings of our support set i.e $f_{\\varphi}(x_i) $"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f_xi = embedding_function(xi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compute the embeddings of our query set i.e $f_{\\varphi}(x_j)$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "f_xj = embedding_function(xj)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we calculated the embeddings and have the feature vectors, we combine both the support set and query set feature vectors i.e $Z(f_{\\varphi}(x_i), f_{\\varphi}(x_j))$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "Z = tf.concat([f_xi,f_xj],axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define our relation function $g_{\\phi}()$ as three layered neural network with relu activations, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def relation_function(x):\n",
    "    w1 = tf.Variable(tf.truncated_normal([2,3]))\n",
    "    b1 = tf.Variable(tf.truncated_normal([3]))\n",
    "    \n",
    "    w2 = tf.Variable(tf.truncated_normal([3,5]))\n",
    "    b2 = tf.Variable(tf.truncated_normal([5]))\n",
    "    \n",
    "    w3 = tf.Variable(tf.truncated_normal([5,1]))\n",
    "    b3 = tf.Variable(tf.truncated_normal([1]))\n",
    "    \n",
    "    #layer1\n",
    "    z1 = (tf.nn.xw_plus_b(x,w1,b1))\n",
    "    a1 = tf.nn.relu(z1)\n",
    "    \n",
    "    #layer2\n",
    "    z2 = tf.nn.xw_plus_b(a1,w2,b2)\n",
    "    a2 = tf.nn.relu(z2)\n",
    "    \n",
    "    #layer3\n",
    "    z3 = tf.nn.xw_plus_b(z2,w3,b3)\n",
    "\n",
    "    #output\n",
    "    y = tf.nn.sigmoid(z3)\n",
    "    \n",
    "    return y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now pass the concatenated feature vectors of support set and query set to the relation function and get the relation scores, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "relation_scores = relation_function(Z)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We define our loss function as mean squared error i.e squared difference between relation scores and actual y value. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "loss_function = tf.reduce_mean(tf.squared_difference(relation_scores,y))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can minimize the loss using adam optimizer,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "optimizer = tf.train.AdamOptimizer(0.1)\n",
    "train = optimizer.minimize(loss_function)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, let's start our tensorflow session"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()\n",
    "sess.run(tf.global_variables_initializer())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now, we randomly sample data points for our support set $x_i$ and query set $x_j$ and train the network, "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Episode 0: loss 0.498 \n",
      "Episode 100: loss 0.250 \n",
      "Episode 200: loss 0.248 \n",
      "Episode 300: loss 0.248 \n",
      "Episode 400: loss 0.247 \n",
      "Episode 500: loss 0.248 \n",
      "Episode 600: loss 0.248 \n",
      "Episode 700: loss 0.248 \n",
      "Episode 800: loss 0.247 \n",
      "Episode 900: loss 0.248 \n"
     ]
    }
   ],
   "source": [
    "for episode in range(1000):\n",
    "    _, loss_value = sess.run([train, loss_function], \n",
    "                             feed_dict={xi:data[:,0:9]+np.random.randn(*np.shape(data[:,0:9]))*0.05,\n",
    "                                        xj:data[:,9:]+np.random.randn(*np.shape(data[:,9:]))*0.05,\n",
    "                                        y:label})\n",
    "    if episode % 100 == 0:\n",
    "        print(\"Episode {}: loss {:.3f} \".format(episode, loss_value))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [default]",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
